The gadget package provides a framework for building interpretable, regionally-partitioned decision trees based on local feature effect estimates (such as ICE/PDP or ALE curves). The core workflow is as follows:

1. Use an effect estimator (e.g. the iml package) to compute local feature effects (for ICE/PDP) for a fitted machine learning model.
2. Create a gadgetTree object and use its $fit() method to recursively partition the data space, optimizing for regional homogeneity in feature effects. Each node is represented by a Node object.
3. Use the $plot() and $plot_tree_structure() methods to visualize, respectively, the partial dependence or ICE behavior of features in each region of the tree, and the tree topology and splits.
4. Use the $extract_split_info() method to summarize the split criteria, node statistics, and regional effect heterogeneity for interpretation and reporting.

The package is modular and extensible: different effect strategies (e.g., partial dependence, accumulated local effects) can be implemented by extending the strategy interface. This design allows users to interpret complex black-box models by partitioning the feature space into regions with distinct, interpretable effect patterns.
# Synthetic example: simulate n = 5000 observations where the slope of x1
# flips sign with x3 (an interaction), so the feature space has two
# clearly distinct regions for a regional-effect tree to recover.
set.seed(123)
n = 5000
x1 = runif(n, -1, 1)
x2 = runif(n, -1, 1)
x3 = runif(n, -1, 1)
# x2 is pure noise; x3 both gates the sign of x1's effect and enters additively.
y = ifelse(x3 > 0, 3 * x1, -3 * x1) + x3 + rnorm(n, sd = 0.3)
syn.data = data.frame(x1, x2, x3, y)
# Fit a random forest via mlr3 (TaskRegr/lrn come from mlr3 + mlr3learners,
# assumed attached earlier in the document).
syn.task = TaskRegr$new("xor", backend = syn.data, target = "y")
syn.learner = lrn("regr.ranger")
syn.learner$train(syn.task)
# Wrap the trained learner for iml and precompute ICE curves on a 20-point grid.
syn.predictor = Predictor$new(syn.learner, data = syn.data[, c("x1", "x2", "x3")], y = syn.data$y)
syn.effect = FeatureEffects$new(syn.predictor, grid.size = 20, method = "ice")
# Partition with the PD strategy: at most 4 splits, a split must improve the
# objective by at least 10% (impr.par), and leaves may be arbitrarily small.
syn.tree.pd = gadgetTree$new(strategy = pdStrategy$new(), n.split = 4, impr.par = 0.1, min.node.size = 1)
syn.tree.pd$fit(effect = syn.effect, data = syn.data, target.feature.name = "y")
syn.tree.pd$plot_tree_structure()
# Tabulate per-node split feature/value, objective values and interaction
# importances (intImp); the recorded output below shows a single split on x3
# near 0, matching the simulated interaction.
syn.esi.pd = syn.tree.pd$extract_split_info()
print(syn.esi.pd)
## id depth n.obs node.type split.feature split.value objective.value intImp
## 1 1 1 5000 root x3 0.0001516948 507271.287 0.9867277
## 2 2 2 2555 left <NA> NA 3871.077 NA
## 3 3 2 2445 right <NA> NA 2861.605 NA
## intImp.parent intImp.x1 intImp.x2 intImp.x3 split.feature.parent
## 1 NA 0.9904222 0.03722123 0.9975999 <NA>
## 2 0.9867277 NA NA NA x3
## 3 0.9867277 NA NA NA x3
## split.value.parent objective.value.parent is.final time
## 1 <NA> NA FALSE 0.027
## 2 0.000151694752275944 507271.3 TRUE NA
## 3 0.000151694752275944 507271.3 TRUE NA
# Plot mean-centered regional ICE/PD curves per leaf (points suppressed).
syn.tree.pd$plot(syn.effect, syn.data, target.feature.name = "y",
show.plot = TRUE, show.point = FALSE, mean.center = TRUE)
### ALE Method
# Same partitioning task driven by accumulated local effects: the ALE
# strategy takes the model itself (plus the number of ALE intervals)
# instead of a precomputed FeatureEffects object.
syn.tree.ale = gadgetTree$new(strategy = aleStrategy$new(), n.split = 3)
syn.tree.ale$fit(model = syn.learner, data = syn.data, target.feature.name = "y", n.intervals = 10)
syn.tree.ale$plot_tree_structure()
# Recorded output: ALE also finds the single x3 split near 0, agreeing
# with the PD strategy above.
syn.esi.ale = syn.tree.ale$extract_split_info()
print(syn.esi.ale)
## id depth n.obs node.type split.feature split.value objective.value intImp
## 1 1 1 5000 root x3 -0.006593762 6084.6506 0.9248943
## 2 2 2 2536 left <NA> NA 223.5526 NA
## 3 3 2 2464 right <NA> NA 233.4393 NA
## intImp.parent intImp.x1 intImp.x2 intImp.x3 split.feature.parent
## 1 NA 0.9153289 0.01917008 0.9845032 <NA>
## 2 0.9248943 NA NA NA x3
## 3 0.9248943 NA NA NA x3
## split.value.parent objective.value.parent is.final time
## 1 NA NA FALSE 0.406
## 2 -0.006593762 6084.651 TRUE NA
## 3 -0.006593762 6084.651 TRUE NA
#object_size(syn.tree.ale)
# Per-depth distribution of split-search time (the `time` column above).
# NOTE(review): bike.esi.pd and bike.esi.ale are not defined in this
# excerpt — presumably created in an earlier bike-sharing section of the
# document; confirm they exist before running the last two boxplots.
boxplot(time ~ depth, data = syn.esi.pd, main = "Distribution of split time per depth - Syn.PD")
boxplot(time ~ depth, data = syn.esi.ale, main = "Distribution of split time per depth - Syn.ALE")
boxplot(time ~ depth, data = bike.esi.pd, main = "Distribution of split time per depth - Bike.PD")
boxplot(time ~ depth, data = bike.esi.ale, main = "Distribution of split time per depth - Bike.ALE")
# Benchmark setup: raise the future package's exportable-globals cap to
# 4 GB, but run sequentially (plan() is from the future package) so the
# timings below are not confounded by parallel scheduling.
set.seed(1)
options(future.globals.maxSize = 4 * 1024 * 1024^2) # 4GB
plan(sequential)
# Simulate a 5-feature benchmark dataset, fit a ranger forest on it, and
# precompute ICE curves with iml.
#
# The target mixes a weak linear x1 effect with an x2 slope that is
# modified both by the factor x3 and by the sign of x1; x4 and x5 are
# uninformative binary features.
#
# @param n    Number of observations to simulate.
# @param seed RNG seed (default 1) so repeated calls are reproducible.
# @return A list with `dat` (data.frame including target `y`) and `eff`
#   (iml FeatureEffects object, ICE curves on a 20-point grid).
datagen_p5 = function(n, seed = 1) {
  set.seed(seed)
  # Features are drawn in the same order as always so the RNG stream
  # (and hence the simulated data) is unchanged.
  dat = data.frame(
    x1 = round(runif(n, -1, 1), 1),
    x2 = round(runif(n, -1, 1), 3),
    x3 = as.factor(sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5))),
    x4 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.7, 0.3)),
    x5 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5))
  )
  # Signal: x2's slope shifts when x3 == 0 and again when x1 > 0.
  signal = 0.2 * dat$x1 - 8 * dat$x2 +
    ifelse(dat$x3 == 0, 16 * dat$x2, 0) +
    ifelse(dat$x1 > 0, 8 * dat$x2, 0)
  # Gaussian noise scaled to 10% of the signal's standard deviation.
  dat$y = signal + rnorm(n, 0, 0.1 * sd(signal))
  predictors = dat[, setdiff(colnames(dat), "y")]
  forest = ranger(y ~ ., data = dat, num.trees = 100)
  predict_fun = function(model, newdata) predict(model, newdata)$predictions
  wrapped = Predictor$new(forest, data = predictors, y = dat$y, predict.function = predict_fun)
  effects = FeatureEffects$new(wrapped, method = "ice", grid.size = 20)
  list(dat = dat, eff = effects)
}
# Simulate a 10-feature benchmark dataset, fit a ranger forest on it, and
# precompute ICE curves with iml.
#
# Extends the 5-feature generator with x6..x10, Gaussian features on very
# different scales; of these only x8 enters the target (coefficient 2).
#
# @param n    Number of observations to simulate.
# @param seed RNG seed (default 1) so repeated calls are reproducible.
# @return A list with `dat` (data.frame including target `y`) and `eff`
#   (iml FeatureEffects object, ICE curves on a 20-point grid).
datagen_p10 = function(n, seed = 1) {
  set.seed(seed)
  # Features are drawn in the same order as always so the RNG stream
  # (and hence the simulated data) is unchanged.
  dat = data.frame(
    x1 = round(runif(n, -1, 1), 1),
    x2 = round(runif(n, -1, 1), 3),
    x3 = as.factor(sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5))),
    x4 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.7, 0.3)),
    x5 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5)),
    x6 = rnorm(n, mean = 1, sd = 5),
    x7 = round(rnorm(n, mean = 10, sd = 10), 2),
    x8 = round(rnorm(n, mean = 100, sd = 15), 4),
    x9 = round(rnorm(n, mean = 1000, sd = 20), 7),
    x10 = rnorm(n, mean = 10000, sd = 25)
  )
  # Signal: same interactions as the p = 5 setting, plus a linear x8 term.
  signal = 0.2 * dat$x1 - 8 * dat$x2 +
    ifelse(dat$x3 == 0, 16 * dat$x2, 0) +
    ifelse(dat$x1 > 0, 8 * dat$x2, 0) + 2 * dat$x8
  # Gaussian noise scaled to 10% of the signal's standard deviation.
  dat$y = signal + rnorm(n, 0, 0.1 * sd(signal))
  predictors = dat[, setdiff(colnames(dat), "y")]
  forest = ranger(y ~ ., data = dat, num.trees = 100)
  predict_fun = function(model, newdata) predict(model, newdata)$predictions
  wrapped = Predictor$new(forest, data = predictors, y = dat$y, predict.function = predict_fun)
  effects = FeatureEffects$new(wrapped, method = "ice", grid.size = 20)
  list(dat = dat, eff = effects)
}
# Simulate a 20-feature benchmark dataset, fit a ranger forest on it, and
# precompute ICE curves with iml.
#
# Extends the 10-feature generator with ten pure-noise standard-normal
# features (noise1..noise10); the target is unchanged from datagen_p10.
#
# @param n    Number of observations to simulate.
# @param seed RNG seed (default 1) so repeated calls are reproducible.
# @return A list with `dat` (data.frame including target `y`) and `eff`
#   (iml FeatureEffects object, ICE curves on a 20-point grid).
datagen_p20 = function(n, seed = 1) {
  set.seed(seed)
  # Features are drawn in the same order as always so the RNG stream
  # (and hence the simulated data) is unchanged.
  dat = data.frame(
    x1 = round(runif(n, -1, 1), 1),
    x2 = round(runif(n, -1, 1), 3),
    x3 = as.factor(sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5))),
    x4 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.7, 0.3)),
    x5 = sample(c(0, 1), size = n, replace = TRUE, prob = c(0.5, 0.5)),
    x6 = rnorm(n, mean = 1, sd = 5),
    x7 = round(rnorm(n, mean = 10, sd = 10), 2),
    x8 = round(rnorm(n, mean = 100, sd = 15), 4),
    x9 = round(rnorm(n, mean = 1000, sd = 20), 7),
    x10 = rnorm(n, mean = 10000, sd = 25)
  )
  # Ten pure-noise columns, drawn after x1..x10 to keep the RNG stream
  # identical to the original generator.
  pure_noise = replicate(10, rnorm(n), simplify = FALSE)
  names(pure_noise) = paste0("noise", 1:10)
  dat = data.frame(dat, pure_noise)
  # Signal: same interactions as the p = 5 setting, plus a linear x8 term;
  # none of the noise columns enter.
  signal = 0.2 * dat$x1 - 8 * dat$x2 +
    ifelse(dat$x3 == 0, 16 * dat$x2, 0) +
    ifelse(dat$x1 > 0, 8 * dat$x2, 0) + 2 * dat$x8
  # Gaussian noise scaled to 10% of the signal's standard deviation.
  dat$y = signal + rnorm(n, 0, 0.1 * sd(signal))
  predictors = dat[, setdiff(colnames(dat), "y")]
  forest = ranger(y ~ ., data = dat, num.trees = 100)
  predict_fun = function(model, newdata) predict(model, newdata)$predictions
  wrapped = Predictor$new(forest, data = predictors, y = dat$y, predict.function = predict_fun)
  effects = FeatureEffects$new(wrapped, method = "ice", grid.size = 20)
  list(dat = dat, eff = effects)
}
# Benchmark grid: every combination of sample size n and feature count p.
n_list = c(1000, 5000, 10000)
p_list = c(5, 10, 20)
bench_results = list()
# Accumulator for per-run tree sizes and memory deltas; rows appended below.
tree_sizes = data.frame(n = integer(), p = integer(), tree_size_MB = numeric(),
mem_before_MB = numeric(), mem_after_MB = numeric(), mem_increase_MB = numeric())
# Initial memory cleanup
gc()
## used (Mb) gc trigger (Mb) limit (Mb) max used (Mb)
## Ncells 2899530 154.9 5215075 278.6 NA 5215075 278.6
## Vcells 14070167 107.4 40000535 305.2 NA 39999922 305.2
# NOTE(review): gc()["Vcells", "used"] is a count of 8-byte vector cells,
# not megabytes — the *_MB labels in this script overstate the units.
# Multiply by 8 / 1024^2 (or read the "(Mb)" column) for true MB; confirm
# before citing these numbers.
initial_mem = gc()["Vcells", "used"]
# Main benchmark loop: for every (n, p) combination, generate data, fit a
# gadgetTree (PD strategy, up to 10 splits), record the tree's object size
# and the Vcells delta, then time the fit 5 times with bench::mark.
# NOTE(review): all mem_* values here are Vcells counts (8-byte cells),
# not MB, despite the labels — see the note at initial_mem above.
for (n in n_list) {
for (p in p_list) {
cat(sprintf("Running: n = %d, p = %d\n", n, p))
# Clean memory and record starting state
gc()
mem_before = gc()["Vcells", "used"]
# Data generation (dispatch on feature count)
if (p == 5) {
sim = datagen_p5(n)
} else if (p == 10) {
sim = datagen_p10(n)
} else if (p == 20) {
sim = datagen_p20(n)
}
# Clean memory after data generation
gc()
# Create and fit tree
tree = gadgetTree$new(strategy = pdStrategy$new(), n.split = 10)
tree$fit(effect = sim$eff, data = sim$dat, target.feature.name = "y")
# Clean memory after tree fitting
gc()
# Calculate tree size (pryr::object_size counts shared objects once)
tree_size_MB = as.numeric(pryr::object_size(tree)) / 1024^2
mem_after = gc()["Vcells", "used"]
mem_increase = mem_after - mem_before
# Record results (rbind-in-loop is fine here: only 9 iterations)
tree_sizes = rbind(tree_sizes, data.frame(
n = n,
p = p,
tree_size_MB = tree_size_MB,
mem_before_MB = mem_before,
mem_after_MB = mem_after,
mem_increase_MB = mem_increase
))
# Clean up tree object
rm(tree)
gc()
# Benchmark the fit itself; the gc() calls inside the expression keep the
# heap comparable across iterations but are included in the timed work.
res = bench::mark(
fit = {
# Clean memory
gc()
tree = gadgetTree$new(strategy = pdStrategy$new(), n.split = 10)
tree$fit(effect = sim$eff, data = sim$dat, target.feature.name = "y")
# Clean memory
gc()
},
iterations = 5
)
res$n = n
res$p = p
bench_results[[paste0("n", n, "_p", p)]] = res
# Clean up sim data
rm(sim)
gc()
cat(sprintf("Memory used: %.2f MB, Tree size: %.2f MB\n", mem_increase, tree_size_MB))
}
}
## Running: n = 1000, p = 5
## Memory used: 522373.00 MB, Tree size: 3.49 MB
## Running: n = 1000, p = 10
## Memory used: 1131087.00 MB, Tree size: 4.09 MB
## Running: n = 1000, p = 20
## Memory used: 2054541.00 MB, Tree size: 4.12 MB
## Running: n = 5000, p = 5
## Memory used: 2302209.00 MB, Tree size: 3.54 MB
## Running: n = 5000, p = 10
## Memory used: 5585787.00 MB, Tree size: 5.58 MB
## Running: n = 5000, p = 20
## Memory used: 10170724.00 MB, Tree size: 5.93 MB
## Running: n = 10000, p = 5
## Memory used: 4357172.00 MB, Tree size: 3.60 MB
## Running: n = 10000, p = 10
## Memory used: 11148419.00 MB, Tree size: 7.05 MB
## Running: n = 10000, p = 20
## Memory used: 20286252.00 MB, Tree size: 6.07 MB
# Final memory cleanup
gc()
## used (Mb) gc trigger (Mb) limit (Mb) max used (Mb)
## Ncells 3653258 195.2 6298091 336.4 NA 6298091 336.4
## Vcells 53709992 409.8 120246076 917.5 32768 120245385 917.4
final_mem = gc()["Vcells", "used"]
cat(sprintf("Total memory increase: %.2f MB\n", final_mem - initial_mem))
## Total memory increase: 39677448.00 MB
# Flatten the per-run bench::mark results into one long data frame with
# one row per timed iteration, then visualise fit time against n and p.
bench_all = do.call(rbind, bench_results)
# Each element of the `time` list-column holds one measurement per
# iteration. Repeat n/p by the actual per-row iteration count instead of
# a hard-coded 5, so this stays correct if the `iterations` argument of
# bench::mark above is ever changed (identical result when all rows have
# exactly 5 timings).
iter_counts = lengths(bench_all$time)
n_vec = rep(bench_all$n, times = iter_counts)
p_vec = rep(bench_all$p, times = iter_counts)
time_vec = unlist(bench_all$time)
# bench_time values are in seconds; convert to milliseconds for plotting.
time_ms = as.numeric(time_vec) * 1000
bench_long = data.frame(
n = n_vec,
p = p_vec,
time_ms = time_ms
)
# One box per (n, p) cell, with jittered raw timings overlaid.
ggplot(bench_long, aes(x = factor(n), y = time_ms, color = factor(p), group = p)) +
geom_boxplot(aes(group = interaction(n, p))) +
geom_jitter(width = 0.2, alpha = 0.5) +
labs(x = "Sample Size (n)", y = "Fit Time (ms)", color = "Feature Number (p)",
title = "gadgetTree$fit(n.split = 10) Benchmark: Varying n and p") +
theme_minimal()
# Recorded tree-size / memory table. NOTE(review): the mem_* columns are
# Vcells counts (8-byte cells), not MB, despite their names.
tree_sizes
## n p tree_size_MB mem_before_MB mem_after_MB mem_increase_MB
## 1 1000 5 3.492287 14034366 14556739 522373
## 2 1000 10 4.094444 14146699 15277786 1131087
## 3 1000 20 4.116432 16188294 18242835 2054541
## 4 5000 5 3.538010 19804925 22107134 2302209
## 5 5000 10 5.582108 19894630 25480417 5585787
## 6 5000 20 5.927483 24622692 34793416 10170724
## 7 10000 5 3.595230 35806504 40163676 4357172
## 8 10000 10 7.049728 35852011 47000430 11148419
## 9 10000 20 6.073471 43342535 63628787 20286252
\[ \begin{aligned} SSE &= \sum_{i=1}^n (y_i-\bar{y})^2\\ &= \sum_{i=1}^n y_i^2 - 2\bar{y}\sum_{i=1}^n y_i + \sum_{i=1}^n \bar{y}^2\\ &= \sum_{i=1}^n y_i^2 - 2\bar{y}\,n\bar{y} + n\bar{y}^2\\ &= \sum_{i=1}^n y_i^2 - n\bar{y}^2\\ &= \sum_{i=1}^n y_i^2 - n\left(\frac{\sum_{i=1}^n y_i}{n}\right)^2\\ &= \sum_{i=1}^n y_i^2 - \frac{1}{n}\left(\sum_{i=1}^n y_i\right)^2\\ &= SS - \frac{S^2}{n} \end{aligned} \] \[ \begin{aligned} SSE_{Reduction} &= SSE_{parent} - SSE_{left} - SSE_{right}\\ &= SS_{parent} - \frac{S_{parent}^2}{n_{parent}} - \left(SS_{left} - \frac{S_{left}^2}{n_{left}}\right) - \left(SS_{right} - \frac{S_{right}^2}{n_{right}}\right) \end{aligned} \] Since \[ n_{parent} = n_{left} + n_{right}, \qquad SS_{parent} = SS_{left} + SS_{right}, \]
then \[ SSE_{Reduction} = -\frac{S_{parent}^2}{n_{parent}} + \frac{S_{left}^2}{n_{left}} + \frac{S_{right}^2}{n_{right}}. \] Because the parent term is fixed for a given node, maximizing the reduction over candidate splits satisfies \[ \max\left(SSE_{Reduction}\right) = \max\left(\frac{S_{left}^2}{n_{left}} + \frac{S_{right}^2}{n_{right}}\right) = \min\left(-\frac{S_{left}^2}{n_{left}} - \frac{S_{right}^2}{n_{right}}\right). \]